#!/usr/bin/env python3
# -*- coding: utf-8 -*-
####################################################################################################
# Copyright (C) by the DBCSR developers group - All rights reserved                                #
# This file is part of the DBCSR library.                                                          #
#                                                                                                  #
# For information on the license, see the LICENSE file.                                            #
# For further information please visit https://dbcsr.cp2k.org                                      #
# SPDX-License-Identifier: GPL-2.0+                                                                #
####################################################################################################

import json
import argparse
from pathlib import Path

from kernels.smm_acc import params_dict_to_kernel, gpu_architectures


# ===============================================================================
def main(gpu_version: str, base_dir: Path) -> None:
    """Generate the C++ header ``parameters.h`` for the given GPU version.

    Reads kernel parameters from ``<base_dir>/parameters_<gpu_version>.json``
    and the warp size from ``<base_dir>/../kernels/gpu_properties.json``. It
    then writes the generated header to ``parameters.h`` in the current working
    directory.

    Both inputs are best-effort:

    - If the parameters file cannot be loaded, the header is generated with an
      empty lookup table.
    - If the warp size cannot be determined, 32 is used.

    In either case a warning naming the cause is printed instead of failing
    silently.

    Args:
        gpu_version: GPU card name (e.g. ``"P100"``) used to select the
            parameters file.
        base_dir: Directory containing the ``parameters_*.json`` files.
    """
    param_fn = base_dir / f"parameters_{gpu_version}.json"
    print(f"GPU version: {gpu_version}")

    all_kernels = _load_kernels(param_fn)
    gpu_warp_size = _load_warp_size(base_dir, param_fn)
    print(f"GPU warp size: {gpu_warp_size}")

    # Construct output
    out = write_parameters_file(all_kernels, gpu_warp_size)

    # Write to C++ header file (relative path: lands in the current directory)
    file_h = "parameters.h"
    if all_kernels:
        print(f"Found {len(all_kernels):,} kernels in file {param_fn}")
    print(f"Printing them to file {file_h}")
    with open(file_h, "w", encoding="utf-8") as f:
        f.write(out)


def _load_kernels(param_fn: Path) -> list:
    """Return the kernels listed in *param_fn*, or ``[]`` if it cannot be loaded."""
    try:
        with param_fn.open("r", encoding="utf-8") as fhandle:
            params_list = json.load(fhandle)
        all_kernels = [params_dict_to_kernel(**params) for params in params_list]
    except Exception as err:  # best-effort: an empty table is still a valid header
        print(
            f"Warning: could not load kernel parameters from {param_fn} ({err!r}); "
            "generating an empty parameter table"
        )
        return []
    print(f"About to process {len(all_kernels):,} kernels from file {param_fn}")
    return all_kernels


def _load_warp_size(base_dir: Path, param_fn: Path, default: int = 32) -> int:
    """Return the warp size for the GPU targeted by *param_fn*, or *default*."""
    gpu_props_fn = base_dir / "../kernels/gpu_properties.json"
    try:
        # gpu_architectures is keyed by the parameters file name; its value is
        # the key of this GPU's entry in gpu_properties.json.
        arch_code = gpu_architectures[param_fn.name]
        with gpu_props_fn.open("r", encoding="utf-8") as fhandle:
            return json.load(fhandle)[arch_code]["Threads_/_Warp"]
    except Exception as err:  # best-effort: fall back to the default warp size
        print(
            f"Warning: could not determine warp size for {param_fn.name} "
            f"from {gpu_props_fn} ({err!r}); using {default}"
        )
        return default


# ===============================================================================
def write_parameters_file(all_pars, gpu_warp_size):
    """Build and return the full text of the generated C++ header ``parameters.h``.

    Args:
        all_pars: Iterable of kernel objects. Each must expose an
            ``as_dict_for_parameters_h`` mapping with at least the keys ``m``,
            ``n``, ``k``, ``algorithm``, ``tile_m``, ``tile_n``, ``w``, ``v``,
            ``threads``, ``grouping``, ``minblocks``, ``perf`` and ``source``.
        gpu_warp_size: Value emitted for the ``warp_size`` constant.

    Returns:
        The header source as one string. It contains one initializer-list entry
        per kernel, in the iteration order of *all_pars*.
    """
    # Header: license banner, include guard and description of the table layout
    header = """\
/*------------------------------------------------------------------------------------------------*
 * Copyright (C) by the DBCSR developers group - All rights reserved                              *
 * This file is part of the DBCSR library.                                                        *
 *                                                                                                *
 * For information on the license, see the LICENSE file.                                          *
 * For further information please visit https://dbcsr.cp2k.org                                    *
 * SPDX-License-Identifier: GPL-2.0+                                                              *
 *------------------------------------------------------------------------------------------------*/

/*****************************************************************************
 *  FILE GENERATED BY SCRIPT 'generate_parameters.py' DO NOT EDIT            *
 *****************************************************************************/

#ifndef PARAMETERS_H
#define PARAMETERS_H

#include "parameters_utils.h"

/*
 * Lookup table: given a triplet (m, n, k) describing a matrix-matrix multiplication,
 * look up its optimal kernel parameters
 *
 * Keys:
 *   (m, n, k)
 *
 * Values: array of 8 integers with elements:
 *   0: mm algorithm (enum defined in libsmm_acc.h, possible values: 1, 2, 3, 4, 5)
 *   1: tile_m
 *   2: tile_n
 *   3: w
 *   4: v
 *   5: threads
 *   6: grouping
 *   7: minblocks
 *
 * Note: for the matrix matrix multiplication algorithms which take less than 8 parameters (i.e. tiny, small, medium),
 * the superfluous parameters are set to 0
 */

"""

    # Warp size.
    # NOTE(review): `extern const T x = value;` in a header is a *definition*
    # with external linkage. Including parameters.h from more than one
    # translation unit would cause duplicate-symbol link errors. Presumably it
    # is included from a single .cpp file; confirm before changing.
    warp_size_line = f"extern const int warp_size = {gpu_warp_size};\n\n"

    # Map of kernel parameters: opening of the initializer list
    map_open = "extern const std::unordered_map<Triplet, KernelParameters> ht = {\n"

    # One initializer-list entry per kernel. Doubled braces are literal braces in
    # str.format; an entry renders like
    #   { {{  4,   5,   6}}, {{ 3,  2,  3,  4,  5,  64, 16,  1 }} },  // perf: 123.4 predicted
    init_list_line = (
        "    {{ {{{{{m:3}, {n:3}, {k:3}}}}},"
        + " {{{{ {algorithm:1}, {tile_m:2}, {tile_n:2}, {w:2}, {v:2}, {threads:3}, {grouping:2}, {minblocks:2} }}}} }},"
        + "  // perf: {perf} {source}\n"
    )

    # Footer: closes the initializer list and the include guard
    footer = """\
};

#endif
//EOF
"""

    print("Get parameters and write to file")
    # Collect the pieces and join once rather than growing a string with `+=`,
    # which is quadratic in the general case.
    entries = [init_list_line.format(**pars.as_dict_for_parameters_h) for pars in all_pars]
    return "".join([header, warp_size_line, map_open, *entries, footer])


# ===============================================================================
if __name__ == "__main__":
    # Command-line entry point: read the GPU version and the directory holding
    # the parameters_*.json files, then generate parameters.h.
    # Each entry is (flag strings, keyword options for add_argument).
    cli_options = (
        (
            ("-g", "--gpu_version"),
            dict(
                metavar="GPU_VERSION",
                default="P100",
                help="GPU card version, used to select the appropriate libsmm_acc parameters file. Default: %(default)s",
            ),
        ),
        (
            ("-d", "--base_dir"),
            dict(
                metavar="BASE_DIR",
                default="parameters/",
                type=Path,
                help="Set the base directory to look for the parameter files. Default: %(default)s",
            ),
        ),
    )

    cli = argparse.ArgumentParser(
        description="Generator of libsmm_acc. The Library for Small Matrix Multiplications on GPU.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    cli_args = cli.parse_args()
    main(gpu_version=cli_args.gpu_version, base_dir=cli_args.base_dir)
